#include "radix.h"
#include <malloc.h>
#include <string.h>
#include <assert.h>
#include <stdio.h>
#include <errno.h>
#include <math.h>
#include <stdlib.h>
/* Defining RAX_DEBUG_MSG selects the printing variants of the debugf() and
 * debugnode() macros defined below in this file. */
#define RAX_DEBUG_MSG
/* Pointer to a unique string literal, usable as a sentinel value that cannot
 * collide with caller-supplied data.
 * NOTE(review): presumably returned by lookup functions when a key is absent;
 * confirm against the callers, which are not visible in this part of the file. */
void *raxNotFound = (void *)"rax-not-found-pointer";

/* Forward declaration: the debugnode() macro expands to a call to this
 * function, which is defined further down. */
void raxDebugShowNode(const char *msg, raxNode *n);

/* Debug tracing helpers.
 *
 * debugf(fmt, ...) prints a "file:function:line:" prefix followed by the
 * printf-style message, but only while the runtime flag raxDebugMsg is
 * non-zero. debugnode(msg, n) dumps a node through raxDebugShowNode().
 *
 * debugf() is wrapped in do { } while (0) so that it behaves as a single
 * statement: "if (x) debugf(...); else ..." parses as intended instead of
 * binding the else to the macro's internal if. Call sites must terminate the
 * invocation with a semicolon. */
#ifdef RAX_DEBUG_MSG
#define debugf(...)                                              \
    do                                                           \
    {                                                            \
        if (raxDebugMsg)                                         \
        {                                                        \
            printf("%s:%s:%d:\t", __FILE__, __func__, __LINE__); \
            printf(__VA_ARGS__);                                 \
            fflush(stdout);                                      \
        }                                                        \
    } while (0)

#define debugnode(msg, n) raxDebugShowNode(msg, n)
#else
#define debugf(...) \
    do              \
    {               \
    } while (0)
#define debugnode(msg, n)
#endif
/* Number of padding bytes to append after 'nodesize' bytes of node characters
 * so that what follows starts at an offset that is a multiple of
 * sizeof(void *). The constant 4 is added to 'nodesize' before taking the
 * remainder.
 * NOTE(review): the 4 presumably stands for sizeof(raxNode) (the header that
 * precedes the characters); confirm against the definition in radix.h.
 * The argument is parenthesized so that expressions such as "a << b" or
 * "c ? x : y" are evaluated as a unit before the addition. */
#define raxPadding(nodesize) ((sizeof(void *) - (((nodesize) + 4) % sizeof(void *))) & (sizeof(void *) - 1))

/* Address of the last child pointer stored in node 'n': start from the end of
 * the node (its current length), step back over the value pointer when the
 * node is a key with a non-NULL value, then step back one child pointer. */
#define raxNodeLastChildPtr(n)                                            \
    ((raxNode **)((char *)(n) + raxNodeCurrentLength(n) -                 \
                  (((n)->iskey && !(n)->isnull) ? sizeof(void *) : 0) -   \
                  sizeof(raxNode *)))

/* Address of the first child pointer of node 'n': it is located right after
 * the n->size characters in n->data and the padding that follows them. */
#define raxNodeFirstChildPtr(n) \
    ((raxNode **)(&(n)->data[(n)->size + raxPadding((n)->size)]))

/* Total size in bytes of node 'n' as currently laid out: header, characters,
 * padding, child pointers (one for a compressed node, n->size otherwise) and
 * one trailing value pointer when the node is a key with a non-NULL value. */
#define raxNodeCurrentLength(n) (                                   \
    sizeof(raxNode) + (n)->size + raxPadding((n)->size) +           \
    sizeof(raxNode *) * ((n)->iscompr ? 1 : (n)->size) +            \
    (((n)->iskey && !(n)->isnull) ? sizeof(void *) : 0))

/* Runtime switch consulted by debugf() and raxDebugShowNode(): debug output
 * is produced only while this is non-zero. Enabled by default. */
static int raxDebugMsg = 1;

/* Turn debug output on (non-zero 'onoff') or off (zero). */
void raxSetDebugMsg(int onoff)
{
    raxDebugMsg = onoff;
}
/* Print a one-line description of node 'n' to stdout, prefixed by 'msg':
 * its address, its characters, the key flag, the size and the address of
 * every child. Does nothing when debug output is disabled. */
void raxDebugShowNode(const char *msg, raxNode *n)
{
    if (!raxDebugMsg)
        return;

    printf("%s: %p [%.*s] key:%u size:%u children:",
           msg, (void *)n, (int)n->size, (char *)n->data, n->iskey, n->size);

    /* A compressed node has exactly one child; otherwise one per character. */
    int remaining = n->iscompr ? 1 : n->size;
    raxNode **slot = raxNodeLastChildPtr(n) - (remaining - 1);
    for (; remaining > 0; remaining--, slot++)
    {
        /* Child pointers are read with memcpy rather than dereferenced. */
        raxNode *kid;
        memcpy(&kid, slot, sizeof(kid));
        printf("%p ", (void *)kid);
    }
    printf("\n");
    fflush(stdout);
}

/* Reset 'ts' to an empty stack backed by its embedded static_items array,
 * with capacity RAX_STACK_STATIC_ITMES and the out-of-memory flag cleared. */
static inline void raxStackInit(raxStack *ts)
{
    ts->oom = 0;
    ts->items = 0;
    ts->maxitems = RAX_STACK_STATIC_ITMES;
    ts->stack = ts->static_items;
}

/* Push 'ptr' on top of the stack 'ts'.
 *
 * When the stack is full its capacity is doubled: the first time by moving
 * the items from the embedded static_items array into a heap buffer, later
 * by reallocating that heap buffer.
 *
 * Returns 1 on success. On allocation failure returns 0, sets ts->oom to 1
 * and errno to ENOMEM; the stack keeps its previous storage and contents. */
static inline int raxStackPush(raxStack *ts, void *ptr)
{
    if (ts->items == ts->maxitems)
    {
        size_t newbytes = sizeof(void *) * ts->maxitems * 2;
        void **grown;

        if (ts->stack == ts->static_items)
        {
            /* First growth: the static array cannot be reallocated, so
             * allocate a heap buffer and copy the existing items into it. */
            grown = malloc(newbytes);
            if (grown != NULL)
                memcpy(grown, ts->static_items, sizeof(void *) * ts->maxitems);
        }
        else
        {
            grown = realloc(ts->stack, newbytes);
        }

        if (grown == NULL)
        {
            /* ts->stack is left untouched, so it is still valid. */
            ts->oom = 1;
            errno = ENOMEM;
            return 0;
        }
        ts->stack = grown;
        ts->maxitems *= 2;
    }
    ts->stack[ts->items] = ptr;
    ts->items++;
    return 1;
}

/* Remove and return the item on top of 'ts', or NULL when the stack is
 * empty. */
static inline void *raxStackPop(raxStack *ts)
{
    return ts->items != 0 ? ts->stack[--ts->items] : NULL;
}

/* Return the item on top of 'ts' without removing it, or NULL when the
 * stack is empty. */
static inline void *raxStackPeek(raxStack *ts)
{
    return ts->items != 0 ? ts->stack[ts->items - 1] : NULL;
}

/* Release the storage of 'ts'. Only a heap buffer is freed; the embedded
 * static_items array is left alone. */
static inline void raxStackFree(raxStack *ts)
{
    if (ts->stack == ts->static_items)
        return;
    free(ts->stack);
}

/* Allocate a new non-compressed node with room for 'children' characters and
 * 'children' child pointers. When 'datafield' is non-zero, room for one
 * trailing value pointer is reserved as well.
 *
 * The returned node has iskey, isnull and iscompr cleared and size set to
 * 'children'; the characters and child pointers are left uninitialized.
 * Returns NULL when the allocation fails. */
raxNode *raxNewNode(size_t children, int datafield)
{
    /* Layout: header, one byte per child, padding, one pointer per child. */
    size_t nodesize = sizeof(raxNode) + children + raxPadding(children) +
                      sizeof(raxNode *) * children;
    if (datafield)
        nodesize += sizeof(void *);

    raxNode *node = malloc(nodesize);
    if (node == NULL)
        return NULL;
    node->iskey = 0;
    node->isnull = 0;
    node->iscompr = 0;
    node->size = children;
    return node;
}

/* Allocate an empty radix tree: zero elements and a single node, the head,
 * created with no children. Returns NULL when either allocation fails. */
rax *raxNew(void)
{
    rax *tree = malloc(sizeof(*tree));
    if (tree != NULL)
    {
        tree->numele = 0;
        tree->numnodes = 1;
        tree->head = raxNewNode(0, 0);
        if (tree->head == NULL)
        {
            /* Could not create the head node: undo the tree allocation. */
            free(tree);
            tree = NULL;
        }
    }
    return tree;
}

/* Grow node 'n' by one pointer so that it can hold the value 'data'.
 * A NULL 'data' needs no storage, so the node is returned unchanged.
 * Returns the (possibly moved) node, or NULL when realloc fails, in which
 * case the original node is still valid. */
raxNode *raxRealocForData(raxNode *n, void *data)
{
    if (data != NULL)
        n = realloc(n, raxNodeCurrentLength(n) + sizeof(void *));
    return n;
}

/* Mark node 'n' as a key and associate 'data' with it. A NULL 'data' is
 * recorded only through the isnull flag; otherwise the pointer is written
 * into the last sizeof(void *) bytes of the node, which the caller must
 * already have reserved. */
void raxSetData(raxNode *n, void *data)
{
    n->iskey = 1;
    n->isnull = (data == NULL);
    if (data == NULL)
        return;

    /* With iskey set and isnull cleared, the current length includes the
     * value slot, which sits at the very end of the node. */
    char *slot = (char *)n + raxNodeCurrentLength(n) - sizeof(void *);
    memcpy(slot, &data, sizeof(data));
}

/* Return the value stored in node 'n': NULL when the isnull flag is set,
 * otherwise the pointer held in the last sizeof(void *) bytes of the node. */
void *raxGetData(raxNode *n)
{
    void *value = NULL;
    if (!n->isnull)
    {
        const char *slot = (const char *)n + raxNodeCurrentLength(n) - sizeof(void *);
        memcpy(&value, slot, sizeof(value));
    }
    return value;
}

/* Add a new, empty child for character 'c' to the non-compressed node 'n',
 * keeping the characters (and the matching child pointers) in ascending
 * order.
 *
 * On success returns the (possibly reallocated) node, stores the new child
 * in '*childptr' and the address of the child pointer inside the node in
 * '*parentlink'. On allocation failure returns NULL and leaves 'n' valid
 * and unmodified. The caller must update whatever pointed at the old 'n'. */
raxNode *raxAddChild(raxNode *n, unsigned char c, raxNode **childptr, raxNode ***parentlink)
{
    assert(n->iscompr == 0);

    /* Compute the node length before and after the insertion. The size is
     * restored for now and bumped only once the node has been rearranged. */
    size_t curlen = raxNodeCurrentLength(n);
    n->size++;
    size_t newlen = raxNodeCurrentLength(n);
    n->size--;

    /* Allocate the child first so that a failure leaves 'n' untouched. */
    raxNode *child = raxNewNode(0, 0);
    if (child == NULL)
        return NULL;

    raxNode *newn = realloc(n, newlen);
    if (newn == NULL)
    {
        free(child);
        return NULL;
    }
    n = newn;

    /* Find the insertion position: the first character greater than 'c'. */
    int pos;
    for (pos = 0; pos < n->size; pos++)
    {
        if (n->data[pos] > c)
            break;
    }

    /* If the node stores a value pointer, move it to the new end of the node
     * first so that the moves below cannot overwrite it. */
    unsigned char *src, *dst;
    if (n->iskey && !n->isnull)
    {
        src = ((unsigned char *)n + curlen - sizeof(void *));
        dst = ((unsigned char *)n + newlen - sizeof(void *));
        memmove(dst, src, sizeof(void *));
    }

    /* Extra bytes the child pointer section moves forward because the
     * character section (with its padding) grew; may be zero. */
    size_t shift = newlen - curlen - sizeof(void *);

    /* Move the child pointers at and after the insertion point forward by
     * 'shift' plus the room needed for the new pointer. Only the
     * (size - pos) pointers after the insertion point are moved. */
    src = n->data + n->size + raxPadding(n->size) + sizeof(raxNode *) * pos;
    memmove(src + shift + sizeof(raxNode *), src, sizeof(raxNode *) * (n->size - pos));

    /* Move the child pointers before the insertion point forward by 'shift'. */
    if (shift)
    {
        src = (unsigned char *)raxNodeFirstChildPtr(n);
        memmove(src + shift, src, sizeof(raxNode *) * pos);
    }

    /* Open a one-byte gap in the character section and store 'c' there. */
    src = n->data + pos;
    memmove(src + 1, src, n->size - pos);
    n->data[pos] = c;
    n->size++;

    /* Store the new child pointer in the slot opened at index 'pos'. */
    src = (unsigned char *)raxNodeFirstChildPtr(n);
    raxNode **childfield = (raxNode **)(src + sizeof(raxNode *) * pos);
    memcpy(childfield, &child, sizeof(child));
    *childptr = child;
    *parentlink = childfield;
    return n;
}

/* Turn the childless, non-compressed node 'n' into a compressed node holding
 * the 'len' bytes at 's', with a single newly allocated empty child.
 *
 * On success returns the (possibly reallocated) node and stores the new
 * child in '*child'; the key flag and value of 'n' are preserved. On
 * allocation failure returns NULL and leaves 'n' valid and unmodified. The
 * caller must update whatever pointed at the old 'n'. */
raxNode *raxCompressNode(raxNode *n, unsigned char *s, size_t len, raxNode **child)
{
    assert(n->size == 0 && n->iscompr == 0);
    void *data = NULL; /* Value to restore after the node layout changes. */
    size_t newsize;
    debugf("Compress node:%.*s\n", (int)len, s);

    /* Allocate the child first so that a failure leaves 'n' untouched. */
    *child = raxNewNode(0, 0);
    if (*child == NULL)
        return NULL;

    newsize = sizeof(raxNode) + len + raxPadding(len) + sizeof(raxNode *);
    if (n->iskey)
    {
        /* Fetch the value now: its position changes once size is updated. */
        data = raxGetData(n);
        if (!n->isnull)
            newsize += sizeof(void *);
    }
    debugf("realloc(%p,%zu)\n", (void *)n, newsize);
    raxNode *newn = realloc(n, newsize);
    if (newn == NULL)
    {
        free(*child);
        return NULL;
    }
    n = newn;

    n->iscompr = 1;
    n->size = len;
    /* Store the compressed characters; exactly 'len' bytes fit in the node. */
    memcpy(n->data, s, len);
    if (n->iskey)
        raxSetData(n, data);
    raxNode **childfield = raxNodeLastChildPtr(n);
    memcpy(childfield, child, sizeof(*child));
    return n;
}

/* Walk the tree from the head following the bytes of 's' (length 'len') for
 * as long as they match.
 *
 * Returns the number of bytes of 's' consumed. The optional outputs are:
 *   '*stopnode' - the node where the walk stopped;
 *   '*plink'    - the address of the pointer that references that node;
 *   '*splitpos' - written only when the stop node is compressed: the index
 *                 inside it where matching stopped (0 when the walk ended
 *                 exactly on entering the node);
 *   'ts'        - when non-NULL, every node that is descended through is
 *                 pushed on this stack (the push result is not checked). */
static inline size_t raxLowWalk(rax *rax, unsigned char *s, size_t len, raxNode **stopnode, raxNode ***plink, int *splitpos, raxStack *ts)
{
    raxNode *cur = rax->head;
    raxNode **link = &rax->head;
    size_t consumed = 0; /* Position in the string. */
    size_t idx = 0;      /* Position in the node's characters / children. */

    while (cur->size && consumed < len)
    {
        debugnode("Lookup current node", cur);
        unsigned char *bytes = cur->data;

        if (cur->iscompr)
        {
            /* Find how far the node's characters and the key agree. */
            idx = 0;
            while (idx < cur->size && consumed < len && bytes[idx] == s[consumed])
            {
                idx++;
                consumed++;
            }
            if (idx != cur->size)
                break; /* Stopped in the middle of the compressed node. */
        }
        else
        {
            /* Linear scan for the child labelled with the current byte. */
            idx = 0;
            while (idx < cur->size && bytes[idx] != s[consumed])
                idx++;
            if (idx == cur->size)
                break; /* No such child. */
            consumed++;
        }

        if (ts)
            raxStackPush(ts, cur);

        /* A compressed node has its only child at index 0. */
        raxNode **slot = raxNodeFirstChildPtr(cur) + (cur->iscompr ? 0 : idx);
        memcpy(&cur, slot, sizeof(cur));
        link = slot;
        /* If the loop ends now, position 0 tells the caller that the walk
         * stopped exactly at the start of the new node. */
        idx = 0;
    }

    debugnode("Lookup stop node is", cur);
    if (stopnode)
        *stopnode = cur;
    if (plink)
        *plink = link;
    if (splitpos && cur->iscompr)
        *splitpos = idx;
    return consumed;
}

/* Insert the key 's' of length 'len' with the associated value 'data'.
 *
 * Returns 1 when a new key was inserted. Returns 0 with errno set to 0 when
 * the key already existed: in that case the previous value is stored in
 * '*old' (when 'old' is not NULL) and the value is replaced only if
 * 'orverwrite' is non-zero. Returns 0 with errno set to ENOMEM when an
 * allocation fails.
 *
 * The function first walks the tree as far as the key matches, then:
 *  - if the whole key was consumed at a node boundary, that node becomes
 *    (or already is) the key;
 *  - ALGO 1: if the walk stopped inside a compressed node because of a
 *    mismatching character, the node is split at that character;
 *  - ALGO 2: if the key ended inside a compressed node, the node is split in
 *    two at that position and the second half becomes the key;
 *  - finally any remaining characters are appended as new nodes. */
int raxGenericInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old, int orverwrite)
{
    size_t i;
    int j = 0; /* Split position when the walk stops in a compressed node. */
    raxNode *h, **parentlink;
    debugf("### Insert %.*s with value %p\n", (int)len, s, data);
    i = raxLowWalk(rax, s, len, &h, &parentlink, &j, NULL);

    /* The whole key was consumed and we are not in the middle of a
     * compressed node: 'h' is the node representing the key. */
    if (i == len && (!h->iscompr || j == 0))
    {
        debugf("### Insert: node representing key exists\n");
        /* Make room for the value pointer if the node has none yet. */
        if (!h->iskey || (h->isnull && orverwrite))
        {
            h = raxRealocForData(h, data);
            if (h)
                memcpy(parentlink, &h, sizeof(h));
        }
        if (h == NULL)
        {
            errno = ENOMEM;
            return 0;
        }
        /* The key already exists: report the old value, optionally update. */
        if (h->iskey)
        {
            if (old)
                *old = raxGetData(h);
            if (orverwrite)
                raxSetData(h, data);
            errno = 0;
            return 0;
        }
        /* Otherwise turn the node into a key (raxSetData sets iskey). */
        raxSetData(h, data);
        rax->numele++;
        return 1;
    }

    if (h->iscompr && i != len)
    {
        /* ALGO 1: mismatch inside a compressed node. The node is replaced by
         * up to three nodes: "trimmed" (the first j characters, if any), the
         * "split" node (a normal node holding the character at position j)
         * and "postfix" (the characters after position j, if any). */
        debugf("ALGO 1:Stopped at compressed node %.*s (%p)\n", h->size, h->data, (void *)h);
        debugf("Still to insert: %.*s\n", (int)(len - i), s + i);
        debugf("Splitting at %d: '%c'\n", j, ((char *)h->data)[j]);
        debugf("Other (key) letter is '%c'\n", s[i]);

        /* 1: Save the next pointer (the only child of the compressed node). */
        raxNode **childfield = raxNodeLastChildPtr(h);
        raxNode *next;
        memcpy(&next, childfield, sizeof(next));
        debugf("Next is %p\n", (void *)next);
        debugf("iskey %d\n", h->iskey);
        if (h->iskey)
        {
            debugf("key value is %p\n", raxGetData(h));
        }

        /* Lengths of the additional nodes we will need. */
        size_t trimmedlen = j;
        size_t postfixlen = h->size - j - 1;
        int split_node_is_key = !trimmedlen && h->iskey && !h->isnull;
        size_t nodesize;

        /* 2: Create the split node. Also allocate the other nodes up front,
         * so that an OOM can be handled while the tree is still untouched. */
        raxNode *splitnode = raxNewNode(1, split_node_is_key);
        raxNode *trimmed = NULL;
        raxNode *postfix = NULL;
        if (trimmedlen)
        {
            nodesize = sizeof(raxNode) + trimmedlen + raxPadding(trimmedlen) + sizeof(raxNode *);
            if (h->iskey && !h->isnull)
                nodesize += sizeof(void *);
            trimmed = malloc(nodesize);
        }
        if (postfixlen)
        {
            nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *);
            postfix = malloc(nodesize);
        }
        if (splitnode == NULL || (trimmedlen && trimmed == NULL) || (postfixlen && postfix == NULL))
        {
            free(splitnode);
            free(trimmed);
            free(postfix);
            errno = ENOMEM;
            return 0;
        }
        /* The split node holds the character at the split position. */
        splitnode->data[0] = h->data[j];

        if (j == 0)
        {
            /* 3a: The split happens at the first character: the split node
             * replaces the old node and inherits its key/value. */
            if (h->iskey)
            {
                void *ndata = raxGetData(h);
                raxSetData(splitnode, ndata);
            }
            memcpy(parentlink, &splitnode, sizeof(splitnode));
        }
        else
        {
            /* 3b: Keep the first j characters in the trimmed node, which
             * inherits the key/value and points to the split node. */
            trimmed->size = j;
            memcpy(trimmed->data, h->data, j);
            trimmed->iscompr = j > 1 ? 1 : 0;
            trimmed->iskey = h->iskey;
            trimmed->isnull = h->isnull;
            if (h->iskey && !h->isnull)
            {
                void *ndata = raxGetData(h);
                raxSetData(trimmed, ndata);
            }
            raxNode **cp = raxNodeLastChildPtr(trimmed);
            memcpy(cp, &splitnode, sizeof(splitnode));
            memcpy(parentlink, &trimmed, sizeof(trimmed));
            parentlink = cp; /* The split node is now referenced from here. */
            rax->numnodes++; /* A node was added, not a key. */
        }

        /* 4: Create the postfix node: what remains of the original
         * compressed node after the split character. */
        if (postfixlen)
        {
            postfix->iskey = 0;
            postfix->isnull = 0;
            postfix->size = postfixlen;
            postfix->iscompr = postfixlen > 1;
            memcpy(postfix->data, h->data + j + 1, postfixlen);
            raxNode **cp = raxNodeLastChildPtr(postfix);
            memcpy(cp, &next, sizeof(next));
            rax->numnodes++;
        }
        else
        {
            /* Nothing remains: the split node links directly to 'next'. */
            postfix = next;
        }

        /* 5: Set the split node's first child to the postfix node. */
        raxNode **splitchild = raxNodeLastChildPtr(splitnode);
        memcpy(splitchild, &postfix, sizeof(postfix));

        /* 6: Continue the insertion from the split node, which will receive
         * a new child for the non-matching character of the key. */
        free(h);
        h = splitnode;
    }
    else if (h->iscompr && i == len)
    {
        /* ALGO 2: the key ends inside a compressed node. Split it into a
         * trimmed node (first j characters) and a postfix node (the rest),
         * the latter being the key. */
        debugf("ALGO 2: stopped at comressed node %.*s (%p) j = %d\n", h->size, h->data, (void *)h, j);

        /* Allocate both nodes up front to fail gracefully on OOM. */
        size_t postfixlen = h->size - j;
        size_t nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *);

        if (data != NULL)
            nodesize += sizeof(void *);
        raxNode *postfix = malloc(nodesize);
        nodesize = sizeof(raxNode) + j + raxPadding(j) + sizeof(raxNode *);
        if (h->iskey && !h->isnull)
            nodesize += sizeof(void *);
        raxNode *trimmed = malloc(nodesize);
        if (postfix == NULL || trimmed == NULL)
        {
            free(postfix);
            free(trimmed);
            errno = ENOMEM;
            return 0;
        }

        /* 1: Save the next pointer. */
        raxNode **childfield = raxNodeLastChildPtr(h);
        raxNode *next;
        memcpy(&next, childfield, sizeof(next));

        /* 2: Create the postfix node, which holds the inserted value. */
        postfix->size = postfixlen;
        postfix->iscompr = postfixlen > 1;
        postfix->iskey = 1;
        postfix->isnull = 0;
        memcpy(postfix->data, h->data + j, postfixlen);
        raxSetData(postfix, data);
        raxNode **cp = raxNodeLastChildPtr(postfix);
        memcpy(cp, &next, sizeof(next));
        rax->numnodes++;

        /* 3: Create the trimmed node, carrying over the old key/value. */
        trimmed->size = j;
        trimmed->iscompr = j > 1;
        trimmed->iskey = 0;
        trimmed->isnull = 0;
        memcpy(trimmed->data, h->data, j);
        memcpy(parentlink, &trimmed, sizeof(trimmed));
        if (h->iskey)
        {
            void *aux = raxGetData(h);
            raxSetData(trimmed, aux);
        }

        /* Make the trimmed node's child pointer reference the postfix node. */
        cp = raxNodeLastChildPtr(trimmed);
        memcpy(cp, &postfix, sizeof(postfix));

        /* Done: ALGO 2 inserts the key completely. */
        rax->numele++;
        free(h);
        return 1;
    }

    /* We walked the tree as far as we could, but there are characters left
     * in the key: insert the missing nodes. */
    while (i < len)
    {
        raxNode *child;
        if (h->size == 0 && len - i > 1)
        {
            /* An empty node followed by several characters: store them in a
             * compressed node instead of a chain of single-child nodes. */
            debugf("Inserting compressed node\n");
            size_t comprsize = len - i;
            if (comprsize > RAX_NODE_MAX_SIZE)
                comprsize = RAX_NODE_MAX_SIZE;
            raxNode *newh = raxCompressNode(h, s + i, comprsize, &child);
            if (newh == NULL)
                goto oom;
            h = newh;
            memcpy(parentlink, &h, sizeof(h));
            parentlink = raxNodeLastChildPtr(h);
            i += comprsize;
        }
        else
        {
            debugf("Inserting normal node\n");
            raxNode **new_parentlink;
            raxNode *newh = raxAddChild(h, s[i], &child, &new_parentlink);
            if (newh == NULL)
                goto oom;
            h = newh;
            memcpy(parentlink, &h, sizeof(h));
            parentlink = new_parentlink;
            i++;
        }
        rax->numnodes++;
        h = child;
    }

    /* 'h' is the node for the full key: make room for the value and set it. */
    raxNode *newh = raxRealocForData(h, data);
    if (newh == NULL)
        goto oom;
    h = newh;
    if (!h->iskey)
        rax->numele++;
    raxSetData(h, data);
    memcpy(parentlink, &h, sizeof(h));

    return 1;
oom:
    /* Out of memory after part of the sub-tree was already modified. If the
     * node we stopped at is a terminal node, it is marked as a key with a
     * NULL value and counted as an element.
     * NOTE(review): the removal of this partial key is commented out below,
     * so the prefix s[0..i) stays in the tree as a NULL-valued key. Re-enable
     * the removal once a raxRemove() implementation is available. */
    if (h->size == 0)
    {
        h->isnull = 1;
        h->iskey = 1;
        rax->numele++;
        // assert(raxRemove(rax, s, i, NULL) != 0);
    }
    errno = ENOMEM;
    return 0;
}

/* Insert the key 's' of length 'len' with value 'data', replacing the value
 * of an already existing key. Thin wrapper around raxGenericInsert() with the
 * overwrite flag set; the return value is passed through unchanged. */
int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old)
{
    const int overwrite = 1;
    return raxGenericInsert(rax, s, len, data, old, overwrite);
}
